5.4 自定义层

要点
  • 自定义参数要善于使用 nn.Parameter

1. 不带参数的层

import torch
import torch.nn.functional as F
from torch import nn

class CenteredLayer(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, X):
        return X - X.mean()

2. 带参数的层

下面自定义实现了一个全连接层,支持自定义参数矩阵:

class MyLinear(nn.Module):
    def __init__(self, in_units, units):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(in_units, units))
        self.bias = nn.Parameter(torch.randn(units,))
    def forward(self, X):
        linear = torch.matmul(X, self.weight.data) + self.bias.data
        return F.relu(linear)
提示

nn.Parameter 来自定义参数,这样可以保留梯度,不要直接 self.weight = torch.randn(in_units, units)

参考文献



© 2023 yanghn. All rights reserved. Powered by Obsidian